#ifndef __NN_LOSS_H__
#define __NN_LOSS_H__
#include "header.hpp"
#include "tensor.hpp"
#include <iostream>
namespace nn {
/// Abstract interface for a training loss.
/// Concrete losses implement a scalar loss value and its gradient
/// with respect to the network output.
class LossBase {
public:
  LossBase() = default;
  // Polymorphic base: the destructor must be virtual so deleting a
  // concrete loss through a LossBase* is well-defined.
  virtual ~LossBase() = default;
  /// Gradient of the loss w.r.t. the predictions (same shape as prediction).
  virtual Dmatrix grad(const Dmatrix &prediction, const Dmatrix &targets) = 0;
  /// Scalar loss value for (prediction, targets).
  virtual double loss(const Dmatrix &prediction, const Dmatrix &targets) = 0;
};

class SoftmaxCrossEntropyLoss : public LossBase {
private:
  double T{};
  Dmatrix _weight;

public:
  inline SoftmaxCrossEntropyLoss() noexcept : LossBase(){};
  inline ~SoftmaxCrossEntropyLoss() noexcept = default;;
  Dmatrix grad(const Dmatrix &prediction, const Dmatrix &targets) override {
    auto grads = prediction;
    grads -= targets;
    return grads / prediction.rows();
  }

  double loss(const Dmatrix &prediction, const Dmatrix &targets)override {

    auto pre_ret = prediction;
    auto maxcoeff = pre_ret.rowwise().maxCoeff();
    for (auto col : pre_ret.colwise()) {
      auto tmp = col - maxcoeff;
      col = tmp.array().exp();
    }

    auto rowsum = pre_ret.rowwise().sum();
    for (auto col : pre_ret.colwise()) {
      auto tmp = col.array() / rowsum.array();
      col = tmp;
    }
    // std::cout << pre_ret<<std::endl;
    // std::cout << targets <<std::endl;
    // pre_ret*targets;
    // pre_ret.transposeInPlace();
    auto pre_ret_trans = pre_ret.transpose();
    auto nll = -(pre_ret_trans * targets).rowwise().sum();
    return (nll.array().log() * -1).sum() / prediction.rows();
  }
}; // class SoftmaxCrossEntropyLoss
} // namespace nn

#endif